{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Transformers 설치 방법\n",
    "! pip install transformers datasets evaluate accelerate\n",
    "# 마지막 릴리스 대신 소스에서 설치하려면, 위 명령을 주석으로 바꾸고 아래 명령을 해제하세요.\n",
    "# ! pip install git+https://github.com/huggingface/transformers.git"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 제로샷(zero-shot) 객체 탐지[[zeroshot-object-detection]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "일반적으로 [객체 탐지](https://huggingface.co/docs/transformers/main/ko/tasks/object_detection)에 사용되는 모델을 학습하기 위해서는 레이블이 지정된 이미지 데이터 세트가 필요합니다.\n",
    "그리고 학습 데이터에 존재하는 클래스(레이블)만 탐지할 수 있다는 한계점이 있습니다.\n",
    "\n",
    "다른 방식을 사용하는 [OWL-ViT](https://huggingface.co/docs/transformers/main/ko/tasks/../model_doc/owlvit) 모델로 제로샷 객체 탐지가 가능합니다.\n",
    "OWL-ViT는 개방형 어휘(open-vocabulary) 객체 탐지기입니다.\n",
    "즉, 레이블이 지정된 데이터 세트에 미세 조정하지 않고 자유 텍스트 쿼리를 기반으로 이미지에서 객체를 탐지할 수 있습니다.\n",
    "\n",
    "OWL-ViT 모델은 멀티 모달 표현을 활용해 개방형 어휘 탐지(open-vocabulary detection)를 수행합니다.\n",
    "[CLIP](https://huggingface.co/docs/transformers/main/ko/tasks/../model_doc/clip) 모델에 경량화(lightweight)된 객체 분류와 지역화(localization) 헤드를 결합합니다.\n",
    "개방형 어휘 탐지는 CLIP의 텍스트 인코더로 free-text 쿼리를 임베딩하고, 객체 분류와 지역화 헤드의 입력으로 사용합니다.\n",
    "이미지와 해당 텍스트 설명을 연결하면 ViT가 이미지 패치(image patches)를 입력으로 처리합니다.\n",
    "OWL-ViT 모델의 저자들은 CLIP 모델을 처음부터 학습(scratch learning)한 후에, bipartite matching loss를 사용하여 표준 객체 인식 데이터셋으로 OWL-ViT 모델을 미세 조정했습니다.\n",
    "\n",
    "이 접근 방식을 사용하면 모델은 레이블이 지정된 데이터 세트에 대한 사전 학습 없이도 텍스트 설명을 기반으로 객체를 탐지할 수 있습니다.\n",
    "\n",
    "이번 가이드에서는 OWL-ViT 모델의 사용법을 다룰 것입니다:\n",
    "- 텍스트 프롬프트 기반 객체 탐지\n",
    "- 일괄 객체 탐지\n",
    "- 이미지 가이드 객체 탐지\n",
    "\n",
    "시작하기 전에 필요한 라이브러리가 모두 설치되어 있는지 확인하세요:\n",
    "```bash\n",
    "pip install -q transformers\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 제로샷(zero-shot) 객체 탐지 파이프라인[[zeroshot-object-detection-pipeline]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`pipeline()`을 활용하면 가장 간단하게 OWL-ViT 모델을 추론해볼 수 있습니다.\n",
    "[Hugging Face Hub에 업로드된 체크포인트](https://huggingface.co/models?pipeline_tag=zero-shot-image-classification&sort=downloads)에서 제로샷(zero-shot) 객체 탐지용 파이프라인을 인스턴스화합니다:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "checkpoint = \"google/owlvit-base-patch32\"\n",
    "detector = pipeline(model=checkpoint, task=\"zero-shot-object-detection\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "다음으로, 객체를 탐지하고 싶은 이미지를 선택하세요.\n",
    "여기서는 [NASA](https://www.nasa.gov/multimedia/imagegallery/index.html) Great Images 데이터 세트의 일부인 우주비행사 에일린 콜린스(Eileen Collins) 사진을 사용하겠습니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import skimage\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "\n",
    "image = skimage.data.astronaut()\n",
    "image = Image.fromarray(np.uint8(image)).convert(\"RGB\")\n",
    "\n",
    "image"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"flex justify-center\">\n",
    "     <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/zero-sh-obj-detection_1.png\" alt=\"Astronaut Eileen Collins\"/>\n",
    "</div>\n",
    "\n",
    "이미지와 해당 이미지의 후보 레이블을 파이프라인으로 전달합니다.\n",
    "여기서는 이미지를 직접 전달하지만, 컴퓨터에 저장된 이미지의 경로나 url로 전달할 수도 있습니다.\n",
    "candidate_labels는 이 예시처럼 간단한 단어일 수도 있고 좀 더 설명적인 단어일 수도 있습니다.\n",
    "또한, 이미지를 검색(query)하려는 모든 항목에 대한 텍스트 설명도 전달합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'score': 0.3571370542049408,\n",
       "  'label': 'human face',\n",
       "  'box': {'xmin': 180, 'ymin': 71, 'xmax': 271, 'ymax': 178}},\n",
       " {'score': 0.28099656105041504,\n",
       "  'label': 'nasa badge',\n",
       "  'box': {'xmin': 129, 'ymin': 348, 'xmax': 206, 'ymax': 427}},\n",
       " {'score': 0.2110239565372467,\n",
       "  'label': 'rocket',\n",
       "  'box': {'xmin': 350, 'ymin': -1, 'xmax': 468, 'ymax': 288}},\n",
       " {'score': 0.13790413737297058,\n",
       "  'label': 'star-spangled banner',\n",
       "  'box': {'xmin': 1, 'ymin': 1, 'xmax': 105, 'ymax': 509}},\n",
       " {'score': 0.11950037628412247,\n",
       "  'label': 'nasa badge',\n",
       "  'box': {'xmin': 277, 'ymin': 338, 'xmax': 327, 'ymax': 380}},\n",
       " {'score': 0.10649408400058746,\n",
       "  'label': 'rocket',\n",
       "  'box': {'xmin': 358, 'ymin': 64, 'xmax': 424, 'ymax': 280}}]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predictions = detector(\n",
    "    image,\n",
    "    candidate_labels=[\"human face\", \"rocket\", \"nasa badge\", \"star-spangled banner\"],\n",
    ")\n",
    "predictions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "이제 예측값을 시각화해봅시다:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import ImageDraw\n",
    "\n",
    "draw = ImageDraw.Draw(image)\n",
    "\n",
    "for prediction in predictions:\n",
    "    box = prediction[\"box\"]\n",
    "    label = prediction[\"label\"]\n",
    "    score = prediction[\"score\"]\n",
    "\n",
    "    xmin, ymin, xmax, ymax = box.values()\n",
    "    draw.rectangle((xmin, ymin, xmax, ymax), outline=\"red\", width=1)\n",
    "    draw.text((xmin, ymin), f\"{label}: {round(score,2)}\", fill=\"white\")\n",
    "\n",
    "image"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"flex justify-center\">\n",
    "     <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/zero-sh-obj-detection_2.png\" alt=\"Visualized predictions on NASA image\"/>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 텍스트 프롬프트 기반 객체 탐지[[textprompted-zeroshot-object-detection-by-hand]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "제로샷 객체 탐지 파이프라인 사용법에 대해 살펴보았으니, 이제 동일한 결과를 복제해보겠습니다.\n",
    "\n",
    "[Hugging Face Hub에 업로드된 체크포인트](https://huggingface.co/models?other=owlvit)에서 관련 모델과 프로세서를 가져오는 것으로 시작합니다.\n",
    "여기서는 이전과 동일한 체크포인트를 사용하겠습니다:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection\n",
    "\n",
    "model = AutoModelForZeroShotObjectDetection.from_pretrained(checkpoint)\n",
    "processor = AutoProcessor.from_pretrained(checkpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "다른 이미지를 사용해 보겠습니다:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "\n",
    "url = \"https://unsplash.com/photos/oj0zeY2Ltk4/download?ixid=MnwxMjA3fDB8MXxzZWFyY2h8MTR8fHBpY25pY3xlbnwwfHx8fDE2Nzc0OTE1NDk&force=true&w=640\"\n",
    "im = Image.open(requests.get(url, stream=True).raw)\n",
    "im"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"flex justify-center\">\n",
    "     <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/zero-sh-obj-detection_3.png\" alt=\"Beach photo\"/>\n",
    "</div>\n",
    "\n",
    "프로세서를 사용해 모델의 입력을 준비합니다.\n",
    "프로세서는 모델의 입력으로 사용하기 위해 이미지 크기를 변환하고 정규화하는 이미지 프로세서와 텍스트 입력을 처리하는 [CLIPTokenizer](https://huggingface.co/docs/transformers/main/ko/model_doc/clip#transformers.CLIPTokenizer)로 구성됩니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "text_queries = [\"hat\", \"book\", \"sunglasses\", \"camera\"]\n",
    "inputs = processor(text=text_queries, images=im, return_tensors=\"pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "모델에 입력을 전달하고 결과를 후처리 및 시각화합니다.\n",
    "이미지 프로세서가 모델에 이미지를 입력하기 전에 이미지 크기를 조정했기 때문에, `post_process_object_detection()` 메소드를 사용해\n",
    "예측값의 바운딩 박스(bounding box)가 원본 이미지의 좌표와 상대적으로 동일한지 확인해야 합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "with torch.no_grad():\n",
    "    outputs = model(**inputs)\n",
    "    target_sizes = torch.tensor([im.size[::-1]])\n",
    "    results = processor.post_process_object_detection(outputs, threshold=0.1, target_sizes=target_sizes)[0]\n",
    "\n",
    "draw = ImageDraw.Draw(im)\n",
    "\n",
    "scores = results[\"scores\"].tolist()\n",
    "labels = results[\"labels\"].tolist()\n",
    "boxes = results[\"boxes\"].tolist()\n",
    "\n",
    "for box, score, label in zip(boxes, scores, labels):\n",
    "    xmin, ymin, xmax, ymax = box\n",
    "    draw.rectangle((xmin, ymin, xmax, ymax), outline=\"red\", width=1)\n",
    "    draw.text((xmin, ymin), f\"{text_queries[label]}: {round(score,2)}\", fill=\"white\")\n",
    "\n",
    "im"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"flex justify-center\">\n",
    "     <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/zero-sh-obj-detection_4.png\" alt=\"Beach photo with detected objects\"/>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 일괄 처리[[batch-processing]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "여러 이미지와 텍스트 쿼리를 전달하여 여러 이미지에서 서로 다른(또는 동일한) 객체를 검색할 수 있습니다.\n",
    "일괄 처리를 위해서 텍스트 쿼리는 이중 리스트로, 이미지는 PIL 이미지, PyTorch 텐서, 또는 NumPy 배열로 이루어진 리스트로 프로세서에 전달해야 합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "images = [image, im]\n",
    "text_queries = [\n",
    "    [\"human face\", \"rocket\", \"nasa badge\", \"star-spangled banner\"],\n",
    "    [\"hat\", \"book\", \"sunglasses\", \"camera\"],\n",
    "]\n",
    "inputs = processor(text=text_queries, images=images, return_tensors=\"pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "이전에는 후처리를 위해 단일 이미지의 크기를 텐서로 전달했지만, 튜플을 전달할 수 있고, 여러 이미지를 처리하는 경우에는 튜플로 이루어진 리스트를 전달할 수도 있습니다.\n",
    "아래 두 예제에 대한 예측을 생성하고, 두 번째 이미지(`image_idx = 1`)를 시각화해 보겠습니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    outputs = model(**inputs)\n",
    "    target_sizes = [x.size[::-1] for x in images]\n",
    "    results = processor.post_process_object_detection(outputs, threshold=0.1, target_sizes=target_sizes)\n",
    "\n",
    "image_idx = 1\n",
    "draw = ImageDraw.Draw(images[image_idx])\n",
    "\n",
    "scores = results[image_idx][\"scores\"].tolist()\n",
    "labels = results[image_idx][\"labels\"].tolist()\n",
    "boxes = results[image_idx][\"boxes\"].tolist()\n",
    "\n",
    "for box, score, label in zip(boxes, scores, labels):\n",
    "    xmin, ymin, xmax, ymax = box\n",
    "    draw.rectangle((xmin, ymin, xmax, ymax), outline=\"red\", width=1)\n",
    "    draw.text((xmin, ymin), f\"{text_queries[image_idx][label]}: {round(score,2)}\", fill=\"white\")\n",
    "\n",
    "images[image_idx]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"flex justify-center\">\n",
    "     <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/zero-sh-obj-detection_4.png\" alt=\"Beach photo with detected objects\"/>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 이미지 가이드 객체 탐지[[imageguided-object-detection]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "텍스트 쿼리를 이용한 제로샷 객체 탐지 외에도 OWL-ViT 모델은 이미지 가이드 객체 탐지 기능을 제공합니다.\n",
    "이미지를 쿼리로 사용해 대상 이미지에서 유사한 객체를 찾을 수 있다는 의미입니다.\n",
    "텍스트 쿼리와 달리 하나의 예제 이미지에서만 가능합니다.\n",
    "\n",
    "소파에 고양이 두 마리가 있는 이미지를 대상 이미지(target image)로, 고양이 한 마리가 있는 이미지를 쿼리로 사용해보겠습니다:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "url = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n",
    "image_target = Image.open(requests.get(url, stream=True).raw)\n",
    "\n",
    "query_url = \"http://images.cocodataset.org/val2017/000000524280.jpg\"\n",
    "query_image = Image.open(requests.get(query_url, stream=True).raw)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "다음 이미지를 살펴보겠습니다:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "fig, ax = plt.subplots(1, 2)\n",
    "ax[0].imshow(image_target)\n",
    "ax[1].imshow(query_image)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"flex justify-center\">\n",
    "     <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/zero-sh-obj-detection_5.png\" alt=\"Cats\"/>\n",
    "</div>\n",
    "\n",
    "전처리 단계에서 텍스트 쿼리 대신에 `query_images`를 사용합니다:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs = processor(images=image_target, query_images=query_image, return_tensors=\"pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "예측의 경우, 모델에 입력을 전달하는 대신 `image_guided_detection()`에 전달합니다.\n",
    "레이블이 없다는 점을 제외하면 이전과 동일합니다.\n",
    "이전과 동일하게 이미지를 시각화합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    outputs = model.image_guided_detection(**inputs)\n",
    "    target_sizes = torch.tensor([image_target.size[::-1]])\n",
    "    results = processor.post_process_image_guided_detection(outputs=outputs, target_sizes=target_sizes)[0]\n",
    "\n",
    "draw = ImageDraw.Draw(image_target)\n",
    "\n",
    "scores = results[\"scores\"].tolist()\n",
    "boxes = results[\"boxes\"].tolist()\n",
    "\n",
    "for box, score, label in zip(boxes, scores, labels):\n",
    "    xmin, ymin, xmax, ymax = box\n",
    "    draw.rectangle((xmin, ymin, xmax, ymax), outline=\"white\", width=4)\n",
    "\n",
    "image_target"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"flex justify-center\">\n",
    "     <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/zero-sh-obj-detection_6.png\" alt=\"Cats with bounding boxes\"/>\n",
    "</div>\n",
    "\n",
    "OWL-ViT 모델을 추론하고 싶다면 아래 데모를 확인하세요:\n",
    "\n",
    "<iframe\n",
    "\tsrc=\"https://adirik-owl-vit.hf.space\"\n",
    "\tframeborder=\"0\"\n",
    "\twidth=\"850\"\n",
    "\theight=\"450\"\n",
    "></iframe>"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 4
}
